In [None]:
# Download the PyTorch/XLA environment setup script from the pytorch/xla repository,
# then run it with `--version nightly` and the apt packages libomp5 and libopenblas-dev.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

In [None]:
!pip install pytorch_lightning -q

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import glob
import matplotlib.pyplot as plt
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

%matplotlib inline

import os
# Print the full path of every file found anywhere under /kaggle/input.
for dirname, _, filenames in os.walk('/kaggle/input'):
    full_paths = [os.path.join(dirname, filename) for filename in filenames]
    for full_path in full_paths:
        print(full_path)
        


# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Read the training table into `df` and preview its first rows as the cell output.
train_csv_path = '/kaggle/input/cassava-leaf-disease-classification/train.csv'
df = pd.read_csv(train_csv_path)
df.head()

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import transforms,models
import pytorch_lightning as pl
import albumentations as A

class ImageClassifier(pl.LightningModule):
    """ResNet-34 backbone followed by a linear classification head.

    The output of the torchvision ResNet-34 is passed through ``self.fc`` to
    produce ``num_classes`` logits. Training and validation both compute the
    cross-entropy between those logits and the targets and log it as
    ``train_loss`` / ``val_loss`` respectively.

    Args:
        num_classes: Number of logits produced by the head (default 5).
        lr: Learning rate handed to Adam (default 1e-3).
        pretrained: Forwarded to ``models.resnet34(pretrained=...)``
            (default True).
    """

    def __init__(self, num_classes=5, lr=1e-3, pretrained=True):
        super().__init__()
        self.lr = lr
        self.model = models.resnet34(pretrained=pretrained)
        # Size the head from the backbone's own output layer rather than
        # hard-coding its width.
        self.fc = nn.Linear(self.model.fc.out_features, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        embedding = self.model(x)
        out = self.fc(embedding)
        return out

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters (backbone and head)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def _shared_step(self, batch, stage):
        """Compute and log the cross-entropy loss for one ``(x, y)`` batch.

        ``stage`` is the prefix of the logged metric name (``'<stage>_loss'``).
        Going through ``self(x)`` keeps both steps in sync with ``forward``.
        """
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log(f'{stage}_loss', loss)
        return loss

    def training_step(self, train_batch, batch_idx):
        """Log ``train_loss`` and return it so Lightning can backpropagate."""
        return self._shared_step(train_batch, 'train')

    def validation_step(self, val_batch, batch_idx):
        """Log ``val_loss`` for the batch; nothing is returned."""
        self._shared_step(val_batch, 'val')
        


In [None]:
class LeafDataset():
    """Map-style dataset yielding ``(image, label)`` tensor pairs.

    Args:
        df: DataFrame with an ``image_id`` column (image file names) and a
            ``label`` column. Rows are addressed by position, not by index label.
        aug: Callable invoked as ``aug(image=ndarray)`` that returns a mapping
            with the transformed array under the ``'image'`` key.
        image_dir: Directory containing the files named in ``image_id``.
            Defaults to the competition's training image folder.
    """

    def __init__(self, df, aug,
                 image_dir='/kaggle/input/cassava-leaf-disease-classification/train_images/'):
        self.df = df
        self.aug = aug
        self.image_dir = image_dir

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, resize and transform the image at position ``idx``.

        Returns:
            A ``(image, label)`` tuple: ``image`` is a float32 tensor in
            channel-first layout (3, 224, 224) and ``label`` is a scalar
            int64 tensor.
        """
        # Read from the frame this dataset was built with. Relying on a
        # notebook-global `df` breaks as soon as a different frame is passed in.
        image_id = self.df.image_id.values[idx]
        label = self.df.label.values[idx]

        path = os.path.join(self.image_dir, image_id)
        # The context manager closes the underlying file once the pixels have
        # been decoded. convert('RGB') guarantees three channels for the
        # (H, W, C) -> (C, H, W) transpose below.
        with Image.open(path) as img:
            img = img.convert('RGB').resize((224, 224), resample=Image.BILINEAR)
        img = np.array(img)

        aug_img = self.aug(image=img)['image']
        aug_img = np.transpose(aug_img, (2, 0, 1)).astype(np.float32)
        return torch.tensor(aug_img, dtype=torch.float), torch.tensor(label, dtype=torch.long)

In [None]:
# data
SEED = 42           # fixes the train/validation split so reruns use the same partition
TRAIN_SIZE = 15000  # number of samples used for training; the rest is validation
BATCH_SIZE = 8
NUM_WORKERS = 4

tfm = A.Compose(
        [
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),max_pixel_value=255.0,always_apply=True),
        ])
dataset = LeafDataset(df,tfm)
# Derive the validation size from the dataset length instead of hard-coding it,
# so random_split does not fail when the CSV row count changes.
val_size = len(dataset) - TRAIN_SIZE
train_ds, val_ds = random_split(
    dataset,
    [TRAIN_SIZE, val_size],
    generator=torch.Generator().manual_seed(SEED),
)
# Reshuffle the training samples every epoch; validation order does not matter.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
# model
model = ImageClassifier()
# training
trainer = pl.Trainer(max_epochs=3,tpu_cores=8)
trainer.fit(model, train_loader, val_loader)