In [None]:
import glob
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

In [None]:
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Folders of *.png images read by CustomDataset (relative to the notebook's
# working directory): first the training split, then the validation split.
path_to_training_data, path_to_validation_data = (
    'Dataset/train/',
    'Dataset/validation/',
)

In [None]:
class CustomDataset(Dataset):
    """Flat directory of ``*.png`` images whose class is encoded in the file name.

    A file is labelled ``k`` when its base name contains the substring
    ``'class<k>'`` for some ``k`` in ``range(n_classes)``; files that match no
    class token fall back to label 0. Each sample is returned as a flattened
    float tensor (``transforms.ToTensor()`` followed by ``view(-1)``) together
    with a scalar int64 label.

    Args:
        path: directory that directly contains the ``.png`` files.
        n_classes: number of class tokens (``class0`` .. ``class{n_classes-1}``)
            to look for in the file names.
        transform: when truthy, every sample is passed through
            ``transforms.RandomRotation(180)`` (random rotation augmentation).
    """

    def __init__(self, path, n_classes=10, transform=False):
        self.do_transform = transform
        self.transform = transforms.RandomRotation(180)
        # Built once here instead of on every __getitem__ call.
        self._to_tensor = transforms.ToTensor()

        # glob returns files in arbitrary filesystem order; sorting keeps sample
        # indices stable across runs and machines.
        self.filelist = sorted(glob.glob(os.path.join(path, '*.png')))
        self.labels = self._labels_from_filenames(self.filelist, n_classes)

    @staticmethod
    def _labels_from_filenames(filelist, n_classes=10):
        """Map each path to the ``k`` of the ``class<k>`` token in its base name.

        Only the base name is inspected, so a directory named e.g. ``class3``
        cannot relabel every file inside it. Classes are scanned in ascending
        order and later matches overwrite earlier ones, so ``class12`` wins over
        its prefix ``class1`` when ``n_classes > 12``. Unmatched files keep 0.

        Args:
            filelist: sequence of file paths.
            n_classes: number of class tokens to test.

        Returns:
            int64 tensor of shape ``(len(filelist),)``.
        """
        names = [os.path.basename(p) for p in filelist]
        labels = np.zeros(len(names), dtype=np.int64)
        for class_i in range(n_classes):
            token = 'class' + str(class_i)
            # Explicit bool dtype: an empty list would otherwise become a float
            # array, which numpy rejects as an index.
            mask = np.array([token in name for name in names], dtype=bool)
            labels[mask] = class_i
        return torch.from_numpy(labels)

    def __len__(self):
        """Number of ``.png`` files found in ``path``."""
        return len(self.filelist)

    def __getitem__(self, idx):
        """Return ``(x, y)``: the flattened image tensor and its int64 label."""
        # The context manager releases the file handle; PIL opens files lazily
        # and would otherwise keep it open until garbage collection.
        with Image.open(self.filelist[idx]) as img:
            if self.do_transform:
                img = self.transform(img)
            x = self._to_tensor(img).view(-1)
        return x, self.labels[idx]

In [None]:
from pytorch_lightning.core.lightning import LightningModule

In [None]:
class LitModel(LightningModule):
    """Two-layer MLP classifier over flattened images, trained with PyTorch Lightning.

    Architecture: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, n_classes)``,
    trained with Adam on cross-entropy. Training and validation data come from
    ``CustomDataset`` over the notebook-level ``path_to_training_data`` and
    ``path_to_validation_data`` folders.

    Args:
        input_dim: length of each flattened input vector. The default 4761 (= 69 * 69)
            presumably corresponds to 69x69 single-channel images — TODO confirm.
        hidden_dim: width of the hidden layer.
        n_classes: number of output logits; also forwarded to ``CustomDataset``.
        lr: Adam learning rate.
        batch_size: batch size used by both dataloaders.
    """

    def __init__(self, input_dim=4761, hidden_dim=4761, n_classes=10, lr=2e-05, batch_size=120):
        super().__init__()
        self.n_classes = n_classes
        self.lr = lr
        self.batch_size = batch_size
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, n_classes)
        self.acti = nn.ReLU()

    def forward(self, x):
        """Return unnormalised class logits for a batch of flattened inputs."""
        out = self.acti(self.layer1(x))
        out = self.layer2(out)
        return out

    def _report_metric(self, name, value):
        """Log ``value`` under ``name`` if this Lightning version provides ``self.log``."""
        log_fn = getattr(self, 'log', None)
        if callable(log_fn):
            log_fn(name, value, prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Cross-entropy loss for one training batch (also reported as ``train_loss``)."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self._report_metric('train_loss', loss)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Shuffled loader over the training folder (no augmentation)."""
        dataset = CustomDataset(path_to_training_data, n_classes=self.n_classes, transform=False)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=0)

    def val_dataloader(self):
        """Unshuffled loader over the validation folder (no augmentation)."""
        dataset = CustomDataset(path_to_validation_data, n_classes=self.n_classes, transform=False)
        return DataLoader(dataset, batch_size=self.batch_size, num_workers=0)

    def validation_step(self, batch, batch_idx):
        """Per-batch validation loss, collected for ``validation_epoch_end``."""
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation losses and report them as ``val_loss``.

        Previously the average was computed and then discarded, so no validation
        metric ever reached the logger or progress bar.
        """
        if not outputs:  # e.g. an empty validation folder; torch.stack([]) would raise
            return None
        avg_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        if callable(getattr(self, 'log', None)):
            self.log('val_loss', avg_loss, prog_bar=True)
            return None
        # NOTE(review): fallback for Lightning releases without `self.log`,
        # presumably the pre-1.0 convention of returning metrics in a dict.
        return {'val_loss': avg_loss, 'log': {'val_loss': avg_loss}}
        
        



In [None]:
from pytorch_lightning import Trainer

In [None]:
# Seed every RNG the run may touch (Python `random`, NumPy, torch) so weight
# initialisation, DataLoader shuffling and any random augmentation are
# reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = LitModel()

In [None]:
# Number of full passes over the training set.
MAX_EPOCHS = 10
trainer = Trainer(max_epochs=MAX_EPOCHS)

In [None]:
# Train `model`; no loaders are passed, so Lightning uses the module's own
# train_dataloader / val_dataloader hooks.
trainer.fit(model=model)