In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import medmnist
from medmnist import INFO, Evaluator
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import accuracy_score

# Constants
BATCH_SIZE = 64
LATENT_DIM = 64
EPOCHS = 20
LEARNING_RATE = 0.001
NUM_CLASSES = 8
MODEL_SAVE_PATH = './classifier_autoencoder_tissuemnist.pt'

# Select the dataset
DATA_FLAG = 'tissuemnist'
info = INFO[DATA_FLAG]
DataClass = getattr(medmnist, info['python_class'])

# Transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# Load the dataset
train_dataset = DataClass(split='train', transform=transform, download=True)
val_dataset = DataClass(split='val', transform=transform, download=True)
test_dataset = DataClass(split='test', transform=transform, download=True)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Autoencoder Model
class Autoencoder(nn.Module):
    def __init__(self, latent_dim):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Sigmoid(),
            nn.Unflatten(1, (1, 28, 28))
        )

    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed

    def encode(self, x):
        return self.encoder(x)

# Classifier Model
class Classifier(nn.Module):
    def __init__(self, encoder, latent_dim, num_classes):
        super(Classifier, self).__init__()
        self.encoder = encoder
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        return self.classifier(encoded)

# Initialize models
autoencoder = Autoencoder(latent_dim=LATENT_DIM)
classifier = Classifier(autoencoder.encoder, LATENT_DIM, num_classes=NUM_CLASSES)

# Loss and optimizer
criterion_autoencoder = nn.MSELoss()
criterion_classifier = nn.CrossEntropyLoss()
optimizer_autoencoder = optim.Adam(autoencoder.parameters(), lr=LEARNING_RATE)
optimizer_classifier = optim.Adam(classifier.parameters(), lr=LEARNING_RATE)

# Training the Autoencoder
def train_autoencoder(model, train_loader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, _ in tqdm(train_loader, desc=f'Autoencoder Training Epoch {epoch + 1}/{epochs}'):
            optimizer.zero_grad()
            images = images.float()
            reconstructed = model(images)
            loss = criterion(reconstructed, images)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}')

# Training the Classifier
def train_classifier(model, train_loader, val_loader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Classifier Training Epoch {epoch + 1}/{epochs}'):
            optimizer.zero_grad()
            images = images.float()
            labels = labels.squeeze().long()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}')
        # Validation
        val_loss, val_acc = validate_classifier(model, val_loader, criterion)
        print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}')

# Validation function for the Classifier
def validate_classifier(model, val_loader, criterion):
    model.eval()
    val_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.float()
            labels = labels.squeeze().long()
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader)
    val_acc = accuracy_score(all_labels, all_preds)
    return val_loss, val_acc

# Testing the Classifier
def test_classifier(model, test_loader):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.float()
            labels = labels.squeeze().long()
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    test_acc = accuracy_score(all_labels, all_preds)
    print(f'Test Accuracy: {test_acc:.4f}')

# Training and evaluation
if __name__ == '__main__':
    train_autoencoder(autoencoder, train_loader, optimizer_autoencoder, criterion_autoencoder, epochs=EPOCHS)
    train_classifier(classifier, train_loader, val_loader, optimizer_classifier, criterion_classifier, epochs=EPOCHS)
    test_classifier(classifier, test_loader)

    # Save the classifier model
    torch.save(classifier.state_dict(), MODEL_SAVE_PATH)
    print(f'Model saved to {MODEL_SAVE_PATH}')

Using downloaded and verified file: C:\Users\metho\.medmnist\tissuemnist.npz
Using downloaded and verified file: C:\Users\metho\.medmnist\tissuemnist.npz
Using downloaded and verified file: C:\Users\metho\.medmnist\tissuemnist.npz


Autoencoder Training Epoch 1/20: 100%|██████████| 2586/2586 [01:11<00:00, 36.35it/s]


Epoch [1/20], Loss: 0.6754


Autoencoder Training Epoch 2/20: 100%|██████████| 2586/2586 [01:24<00:00, 30.78it/s]


Epoch [2/20], Loss: 0.6735


Autoencoder Training Epoch 3/20: 100%|██████████| 2586/2586 [01:28<00:00, 29.10it/s]


Epoch [3/20], Loss: 0.6735


Autoencoder Training Epoch 4/20: 100%|██████████| 2586/2586 [01:24<00:00, 30.78it/s]


Epoch [4/20], Loss: 0.6735


Autoencoder Training Epoch 5/20: 100%|██████████| 2586/2586 [01:23<00:00, 31.02it/s]


Epoch [5/20], Loss: 0.6735


Autoencoder Training Epoch 6/20: 100%|██████████| 2586/2586 [01:23<00:00, 30.91it/s]


Epoch [6/20], Loss: 0.6735


Autoencoder Training Epoch 7/20: 100%|██████████| 2586/2586 [01:24<00:00, 30.44it/s]


Epoch [7/20], Loss: 0.6735


Autoencoder Training Epoch 8/20: 100%|██████████| 2586/2586 [01:24<00:00, 30.55it/s]


Epoch [8/20], Loss: 0.6735


Autoencoder Training Epoch 9/20: 100%|██████████| 2586/2586 [01:25<00:00, 30.31it/s]


Epoch [9/20], Loss: 0.6735


Autoencoder Training Epoch 10/20: 100%|██████████| 2586/2586 [01:24<00:00, 30.74it/s]


Epoch [10/20], Loss: 0.6735


Autoencoder Training Epoch 11/20: 100%|██████████| 2586/2586 [01:24<00:00, 30.54it/s]


Epoch [11/20], Loss: 0.6735


Autoencoder Training Epoch 12/20: 100%|██████████| 2586/2586 [01:25<00:00, 30.32it/s]


Epoch [12/20], Loss: 0.6735


Autoencoder Training Epoch 13/20: 100%|██████████| 2586/2586 [01:25<00:00, 30.41it/s]


Epoch [13/20], Loss: 0.6735


Autoencoder Training Epoch 14/20: 100%|██████████| 2586/2586 [01:26<00:00, 29.81it/s]


Epoch [14/20], Loss: 0.6735


Autoencoder Training Epoch 15/20: 100%|██████████| 2586/2586 [01:31<00:00, 28.28it/s]


Epoch [15/20], Loss: 0.6735


Autoencoder Training Epoch 16/20: 100%|██████████| 2586/2586 [01:26<00:00, 29.99it/s]


Epoch [16/20], Loss: 0.6735


Autoencoder Training Epoch 17/20: 100%|██████████| 2586/2586 [01:31<00:00, 28.36it/s]


Epoch [17/20], Loss: 0.6735


Autoencoder Training Epoch 18/20: 100%|██████████| 2586/2586 [01:32<00:00, 27.93it/s]


Epoch [18/20], Loss: 0.6735


Autoencoder Training Epoch 19/20: 100%|██████████| 2586/2586 [01:34<00:00, 27.39it/s]


Epoch [19/20], Loss: 0.6735


Autoencoder Training Epoch 20/20: 100%|██████████| 2586/2586 [01:36<00:00, 26.86it/s]


Epoch [20/20], Loss: 0.6735


Classifier Training Epoch 1/20: 100%|██████████| 2586/2586 [00:35<00:00, 73.85it/s]


Epoch [1/20], Loss: 1.6288
Validation Loss: 1.5147, Accuracy: 0.4465


Classifier Training Epoch 2/20: 100%|██████████| 2586/2586 [00:40<00:00, 63.58it/s]


Epoch [2/20], Loss: 1.4724
Validation Loss: 1.4246, Accuracy: 0.4822


Classifier Training Epoch 3/20: 100%|██████████| 2586/2586 [00:41<00:00, 62.33it/s]


Epoch [3/20], Loss: 1.4297
Validation Loss: 1.4050, Accuracy: 0.4960


Classifier Training Epoch 4/20: 100%|██████████| 2586/2586 [00:43<00:00, 58.83it/s]


Epoch [4/20], Loss: 1.4211
Validation Loss: 1.4008, Accuracy: 0.4940


Classifier Training Epoch 5/20: 100%|██████████| 2586/2586 [00:58<00:00, 43.86it/s]


Epoch [5/20], Loss: 1.4132
Validation Loss: 1.4194, Accuracy: 0.4772


Classifier Training Epoch 6/20: 100%|██████████| 2586/2586 [01:30<00:00, 28.70it/s]


Epoch [6/20], Loss: 1.4103
Validation Loss: 1.4032, Accuracy: 0.4927


Classifier Training Epoch 7/20: 100%|██████████| 2586/2586 [01:32<00:00, 27.92it/s]


Epoch [7/20], Loss: 1.4054
Validation Loss: 1.4244, Accuracy: 0.4793


Classifier Training Epoch 8/20: 100%|██████████| 2586/2586 [01:32<00:00, 28.01it/s]


Epoch [8/20], Loss: 1.3975
Validation Loss: 1.3859, Accuracy: 0.4971


Classifier Training Epoch 9/20: 100%|██████████| 2586/2586 [01:31<00:00, 28.24it/s]


Epoch [9/20], Loss: 1.3927
Validation Loss: 1.3976, Accuracy: 0.4971


Classifier Training Epoch 10/20: 100%|██████████| 2586/2586 [01:31<00:00, 28.41it/s]


Epoch [10/20], Loss: 1.3857
Validation Loss: 1.4103, Accuracy: 0.4881


Classifier Training Epoch 11/20: 100%|██████████| 2586/2586 [01:31<00:00, 28.23it/s]


Epoch [11/20], Loss: 1.3802
Validation Loss: 1.3776, Accuracy: 0.4974


Classifier Training Epoch 12/20: 100%|██████████| 2586/2586 [01:33<00:00, 27.70it/s]


Epoch [12/20], Loss: 1.3738
Validation Loss: 1.3642, Accuracy: 0.4996


Classifier Training Epoch 13/20: 100%|██████████| 2586/2586 [01:33<00:00, 27.65it/s]


Epoch [13/20], Loss: 1.3697
Validation Loss: 1.3518, Accuracy: 0.5096


Classifier Training Epoch 14/20: 100%|██████████| 2586/2586 [01:02<00:00, 41.23it/s]


Epoch [14/20], Loss: 1.3688
Validation Loss: 1.3692, Accuracy: 0.5032


Classifier Training Epoch 15/20: 100%|██████████| 2586/2586 [00:46<00:00, 56.07it/s]


Epoch [15/20], Loss: 1.3657
Validation Loss: 1.3668, Accuracy: 0.4943


Classifier Training Epoch 16/20: 100%|██████████| 2586/2586 [00:45<00:00, 57.33it/s]


Epoch [16/20], Loss: 1.3604
Validation Loss: 1.3635, Accuracy: 0.5018


Classifier Training Epoch 17/20: 100%|██████████| 2586/2586 [00:45<00:00, 56.69it/s]


Epoch [17/20], Loss: 1.3594
Validation Loss: 1.3850, Accuracy: 0.4914


Classifier Training Epoch 18/20: 100%|██████████| 2586/2586 [00:45<00:00, 56.91it/s]


Epoch [18/20], Loss: 1.3554
Validation Loss: 1.3687, Accuracy: 0.4956


Classifier Training Epoch 19/20: 100%|██████████| 2586/2586 [00:46<00:00, 55.92it/s]


Epoch [19/20], Loss: 1.3515
Validation Loss: 1.3629, Accuracy: 0.5066


Classifier Training Epoch 20/20: 100%|██████████| 2586/2586 [00:46<00:00, 55.19it/s]


Epoch [20/20], Loss: 1.3452
Validation Loss: 1.3622, Accuracy: 0.4917
Test Accuracy: 0.4912
Model saved to ./classifier_autoencoder_tissuemnist.pt
