In [1]:
import os
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.models import swin_v2_t
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
import pandas as pd
from datasets import SingleImgPillID
from torch.utils.tensorboard import SummaryWriter
from train_test import train, test


class SwinTransformer(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.model = swin_v2_t()
        self.model.head = nn.Linear(in_features=768, out_features=n_classes, bias=True)

    def forward(self,x):
        return self.model(x)




BASE_DATA_DIR = os.path.join("data", "ePillID_data")
SPLITS_DIR = os.path.join(BASE_DATA_DIR,'folds/pilltypeid_nih_sidelbls0.01_metric_5folds/base') 

NUM_EPOCHS = 25
NUM_CLASSES = 4902
BATCH_SIZE = 16

train_encoder = LabelEncoder()
test_encoder = LabelEncoder()

csv_files = glob(os.path.join(SPLITS_DIR,'*.csv'))
all_imgs_df = [x for x in csv_files if x.endswith('all.csv')]
train_files = sorted([x for x in csv_files if not x.endswith('all.csv')])
test_file = train_files.pop(-1)


train_df = []
for i in range(len(train_files)):
    train_df.append(pd.read_csv(train_files[i]))

train_encoder.fit(train_df[0]['label'])

test_df = pd.read_csv(test_file)
test_encoder.fit(test_df['label'])


train_transform = transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(brightness=0.4, saturation=0.4, hue=0.4),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


train_dataset = SingleImgPillID(df=train_df[0], label_encoder=train_encoder, transform=train_transform, train=True)
test_dataset = SingleImgPillID(df=test_df, label_encoder=test_encoder, transform=test_transform, train=False)


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=6)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SwinTransformer(NUM_CLASSES).to(device)
optimizer = optim.AdamW(model.parameters(), lr= 5e-5)

writer = SummaryWriter()

print('Training Model')
for epoch in range(NUM_EPOCHS):
    train(model, optimizer, train_loader, device, BATCH_SIZE, epoch, writer)
    print()

print('Test')
test(model, optimizer, test_loader, device, BATCH_SIZE, 1, writer)






Training Model


100%|██████████| 47/47 [00:09<00:00,  5.02it/s]


Epoch 0: Loss 7.847196, Accuracy 0.53%



100%|██████████| 47/47 [00:08<00:00,  5.63it/s]


Epoch 1: Loss 6.250442, Accuracy 0.80%



100%|██████████| 47/47 [00:08<00:00,  5.77it/s]


Epoch 2: Loss 5.555408, Accuracy 0.53%



100%|██████████| 47/47 [00:08<00:00,  5.78it/s]


Epoch 3: Loss 5.285556, Accuracy 0.40%



100%|██████████| 47/47 [00:07<00:00,  5.90it/s]


Epoch 4: Loss 5.137386, Accuracy 0.93%



100%|██████████| 47/47 [00:07<00:00,  6.02it/s]


Epoch 5: Loss 5.045146, Accuracy 0.80%



100%|██████████| 47/47 [00:07<00:00,  5.98it/s]


Epoch 6: Loss 4.926218, Accuracy 0.53%



100%|██████████| 47/47 [00:08<00:00,  5.83it/s]


Epoch 7: Loss 4.873803, Accuracy 1.20%



100%|██████████| 47/47 [00:07<00:00,  5.93it/s]


Epoch 8: Loss 4.751190, Accuracy 1.46%



100%|██████████| 47/47 [00:07<00:00,  6.01it/s]


Epoch 9: Loss 4.709505, Accuracy 0.66%



100%|██████████| 47/47 [00:07<00:00,  5.91it/s]


Epoch 10: Loss 4.661379, Accuracy 0.80%



100%|██████████| 47/47 [00:07<00:00,  5.98it/s]


Epoch 11: Loss 4.588043, Accuracy 1.60%



100%|██████████| 47/47 [00:07<00:00,  5.96it/s]


Epoch 12: Loss 4.562721, Accuracy 1.99%



100%|██████████| 47/47 [00:07<00:00,  5.96it/s]


Epoch 13: Loss 4.500560, Accuracy 2.13%



100%|██████████| 47/47 [00:07<00:00,  6.06it/s]


Epoch 14: Loss 4.502749, Accuracy 2.93%



100%|██████████| 47/47 [00:07<00:00,  6.04it/s]


Epoch 15: Loss 4.409604, Accuracy 2.53%



100%|██████████| 47/47 [00:07<00:00,  6.04it/s]


Epoch 16: Loss 4.432275, Accuracy 1.60%



100%|██████████| 47/47 [00:07<00:00,  5.95it/s]


Epoch 17: Loss 4.374491, Accuracy 2.53%



100%|██████████| 47/47 [00:07<00:00,  6.07it/s]


Epoch 18: Loss 4.322076, Accuracy 3.19%



100%|██████████| 47/47 [00:07<00:00,  6.09it/s]


Epoch 19: Loss 4.307204, Accuracy 3.46%



100%|██████████| 47/47 [00:07<00:00,  6.00it/s]


Epoch 20: Loss 4.283809, Accuracy 3.46%



100%|██████████| 47/47 [00:07<00:00,  6.06it/s]


Epoch 21: Loss 4.334257, Accuracy 4.26%



100%|██████████| 47/47 [00:07<00:00,  6.09it/s]


Epoch 22: Loss 4.249318, Accuracy 3.32%



100%|██████████| 47/47 [00:07<00:00,  6.02it/s]


Epoch 23: Loss 4.220207, Accuracy 4.12%



100%|██████████| 47/47 [00:07<00:00,  6.03it/s]


Epoch 24: Loss 4.222216, Accuracy 3.72%

Test


100%|██████████| 47/47 [00:02<00:00, 16.64it/s]

Epoch 1: Loss 6.638869, Accuracy 0.80%



