# In [None]:
import os
from glob import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.models import swin_v2_t
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
import pandas as pd
from datasets import SingleImgPillID
from torch.utils.tensorboard import SummaryWriter
from train_test import train, test


class SwinTransformer(nn.Module):
    """Swin Transformer V2 (tiny) backbone with a replaced classification head.

    Wraps torchvision's ``swin_v2_t`` and swaps its final ``head`` layer for a
    fresh ``nn.Linear`` that emits ``n_classes`` outputs.

    NOTE(review): ``swin_v2_t()`` is called without a ``weights`` argument, so
    presumably no pretrained weights are loaded — confirm this is intended.

    Args:
        n_classes: Number of outputs of the new ``head`` layer.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.model = swin_v2_t()
        # Take the feature width from the head being replaced instead of
        # hard-coding 768, so the new layer always matches the backbone.
        in_features = self.model.head.in_features
        self.model.head = nn.Linear(in_features=in_features, out_features=n_classes, bias=True)

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output.

        Assumes ``x`` is a float image batch shaped (batch, 3, H, W) —
        TODO confirm against the dataset.
        """
        return self.model(x)




# --- Run configuration -------------------------------------------------------
# Training hyper-parameters.
BATCH_SIZE = 16      # samples per DataLoader batch
NUM_CLASSES = 4902   # width of the classifier head passed to SwinTransformer
NUM_EPOCHS = 25      # number of passes made by the training loop

# Dataset locations, relative to the working directory.
BASE_DATA_DIR = os.path.join("data", "ePillID_data")
# Directory that holds the per-fold CSV split files.
SPLITS_DIR = os.path.join(
    BASE_DATA_DIR,
    'folds/pilltypeid_nih_sidelbls0.01_metric_5folds/base',
)

# --- Fold discovery ----------------------------------------------------------
csv_files = glob(os.path.join(SPLITS_DIR, '*.csv'))
# NOTE(review): despite its name this is a list of *paths* ending in
# 'all.csv', not a DataFrame, and nothing below reads it. Kept so the name
# stays defined for any later code that may rely on it.
all_imgs_df = [x for x in csv_files if x.endswith('all.csv')]
# Sorted so that the fold held out for testing (the last one) is deterministic.
train_files = sorted(x for x in csv_files if not x.endswith('all.csv'))
if len(train_files) < 2:
    # Fail with a clear message instead of an IndexError from pop()/[0] below.
    raise FileNotFoundError(
        f"Expected at least two fold CSVs in {SPLITS_DIR!r}, "
        f"found {len(train_files)}"
    )
test_file = train_files.pop(-1)

# --- Load the fold tables ----------------------------------------------------
train_df = [pd.read_csv(path) for path in train_files]
test_df = pd.read_csv(test_file)

# --- Label encoding ----------------------------------------------------------
# Use ONE encoder for both splits, fitted on the labels of every fold.
# Fitting separate encoders per split (as before) gives each split its own
# label -> index mapping, so the same label could map to different class
# indices at train and test time.
# NOTE(review): assumes SingleImgPillID uses the encoder to turn 'label'
# values into class indices, and that the number of distinct labels does not
# exceed NUM_CLASSES — confirm against the dataset implementation.
train_encoder = LabelEncoder()
train_encoder.fit(pd.concat([*train_df, test_df], ignore_index=True)['label'])
test_encoder = train_encoder  # same object: identical mapping for evaluation

# --- Image transforms --------------------------------------------------------
# Both pipelines end with the same per-channel Normalize so train and test
# inputs are scaled identically.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, saturation=0.4, hue=0.4),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# Test-time pipeline has no random augmentation: resize then centre crop.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

# --- Datasets and loaders ----------------------------------------------------
# NOTE(review): only the first training fold (train_df[0]) is used although
# every fold is read above — confirm whether the remaining folds should be
# included in training.
train_dataset = SingleImgPillID(df=train_df[0], label_encoder=train_encoder, transform=train_transform, train=True)
test_dataset = SingleImgPillID(df=test_df, label_encoder=test_encoder, transform=test_transform, train=False)

# Shuffle training samples every epoch so batches are not drawn in CSV order;
# the test loader keeps its fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=6)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=6)

# --- Model, optimiser and logging --------------------------------------------
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = SwinTransformer(NUM_CLASSES).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-5)

# SummaryWriter() with no arguments writes to its default log directory.
writer = SummaryWriter()

# `train` and `test` come from the project's train_test module; their
# argument contracts are not visible here, so the calls are left as they were.
try:
    print('Training Model')
    for epoch in range(NUM_EPOCHS):
        train(model, optimizer, train_loader, device, BATCH_SIZE, epoch, writer)
        print()

    print('Test')
    test(model, optimizer, test_loader, device, BATCH_SIZE, 1, writer)
finally:
    # Always flush and release the event file, even if training/testing raises,
    # so logged scalars are not lost.
    writer.close()




