In [1]:
from pathlib import Path
import torch
from tqdm import tqdm

import torchvision
from torchvision import transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import numpy as np


In [22]:
!ls -l ../data/{msl,hirise}

../data/hirise:
total 0
-rw-r--r--   1 tg  staff    0 Apr 24 11:30 _VALID
drwxr-xr-x  10 tg  staff  320 Apr 24 11:22 [1m[36mtest[m[m
drwxr-xr-x  10 tg  staff  320 Apr 24 11:22 [1m[36mtrain[m[m
drwxr-xr-x  10 tg  staff  320 Apr 24 11:29 [1m[36mval[m[m

../data/msl:
total 0
-rw-r--r--   1 tg  staff    0 Apr 24 11:16 _VALID
drwxr-xr-x  21 tg  staff  672 Apr 24 11:16 [1m[36mtest[m[m
drwxr-xr-x  21 tg  staff  672 Apr 24 11:14 [1m[36mtrain[m[m
drwxr-xr-x  21 tg  staff  672 Apr 24 11:16 [1m[36mval[m[m


In [33]:
!ls -1 ../data/msl/train  | wc -l

      19


In [7]:
from torch import nn
from torch import Tensor
import logging as log
import json

log.basicConfig(level=log.INFO)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


"""
Pretrained parent model: place this outside of child module to exclude from its graph 
"""
class PreTrained:
    
    # lazy load
    resnet = torchvision.models.resnext101_32x8d(pretrained=True)
    resnet.eval()


class ImageClassifier(nn.Module):

    def __init__(self, n_classes, pre_classes=1000):
        super().__init__()
        self.fc = nn.Linear(pre_classes, n_classes)
        
    def forward(self, xs: Tensor):
        feats = PreTrained.resnet(xs)
        return self.fc(feats)

model = ImageClassifier(n_classes=19)
optim = torch.optim.Adam(model.parameters(), lr=0.0005, betas=(0.9, 0.999), weight_decay=1e-4)

In [None]:
    
class Trainer:
    
    def __init__(self, work_dir:Path, data:Path, model: nn.Module, optim, device=device):
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.train_transform = T.Compose([T.RandomResizedCrop(224),
                              T.RandomHorizontalFlip(),
                              T.ToTensor(), self.normalize])
        self.eval_transform = T.Compose([T.Resize(256), T.CenterCrop(224),
                             T.ToTensor(), self.normalize])
        self.train_data = ImageFolder(data / 'train', transform=self.train_transform)
        self.val_data = ImageFolder(data / 'val', transform=self.eval_transform)
        
        self.criterion = nn.CrossEntropyLoss()
        self.model = model.to(device)
        self.optim = optim

        work_dir.mkdir(parents=True, exist_ok=True)
        self.work_dir = work_dir
        self.models_dir = work_dir / 'models'
        
        self._state = dict(step=0, epoch=0, last_checkpt=None, train_loss=[], val_loss=[])
        self._state_file = self.work_dir / 'state.json'
        if self._state_file.exists():
            self._state = json.loads(self._state_file.read_text())
            log.info(f"state={self._state}")
        
        self.step = self._state['step']
        self.epoch = self._state['epoch']
        
        
    def _checkpt_name(self, step, train_loss, val_loss):
        return f'model_{step:6d}_{train_loss:.5f}_{val_loss:.5f}.pkl'

    def _checkpoint(self, train_metrics, val_loader) -> bool:
        with torch.no_grad():
            self.model.eval()
            val_loss, val_acc = self.validate(val_loader)
            train_loss = train_metrics['loss']
            checkpt_path = self.models_dir / self._checkpt_name(self.step, train_loss, val_loss)
            
            state = dict(
                model_state = self.model.state_dict(),
                step = self.step,
                epoch = self.epoch,
                train_stats = train_metrics,
                val_metrics = dict(loss=val_loss, accuracy=val_acc)
            )
            log.info(f"Checkpoint {checkpt_path}")
            torch.save(state, checkpt_path)
            #self.checkpoint(step=self.step, train_loss=train_loss)

        self._state['val_loss'].append(val_loss)
        self._state['train_loss'].append(train_loss)
        
    def validate(self, val_loader):
         for xs, ys in val_loader:
            pass

    def train(self, max_step=10**6, max_epoch=10**3, batch_size=1,
              num_threads=0, checkpoint=1000):
        train_loader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True,
                                  num_workers=num_threads, pin_memory=True)
        val_loader = DataLoader(self.val_data,batch_size=batch_size, shuffle=True,
                                num_workers=num_threads, pin_memory=True)

        
        # todo resume these values from checkpoint
        if self.step > 0:
            log.info(f'resuming from step {self.step}; max_steps:{max_step}')

        force_stop = False  # early stop when val metric goes up
        while not force_stop and self.step <= max_step and self.epoch <= max_epoch:
            train_losses = []
            train_accs = []
            for xs, ys in tqdm(train_loader): 
                self.step += 1
                output = self.model(xs)
                loss = self.criterion(output, ys)

                train_losses.append(loss.item())
                train_accs.append(accuracy(output.data, ys))
                
                  # compute gradient and do SGD step
                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                # 
                if self.step % checkpoint == 0:
                    metrics = dict(loss=np.mean(train_losses), accuracy=np.mean(train_accs))
                    force_stop = self._checkpoint(metrics, val_loader)
                    train_losses.clear()
                    train_accs.clear()
                    break
                if self.step > max_step:
                    log.info("Max steps reached;")
                    break

            if not force_stop and self.step < max_step:
                self.epoch += 1        
                
def accuracy(output, target):
    """Computes accuracy"""
    batch_size = target.size(0)
    _, top_idx = output.max(dim=1)
    correct = top_idx.eq(target).float().sum()
    return 100.0 * correct/batch_size

data = Path('../data/msl/')
work = Path('../tmp.train')
trainer = Trainer(data=data, work_dir=work, model=model, optim=optim)
trainer

In [13]:
trainer.train(batch_size=1)

INFO:root:resuming from step 8; max_steps:1000000
  0%|          | 16/5920 [00:41<4:15:26,  2.60s/it]


KeyboardInterrupt: 