## Vanilla PyTorch training on TinyImageNet dataset (XPU version)

This notebook is intended to show that fixing random seeds leads to the same result in both federated and non-federated cases. 

It also serves as a simple example of using Intel® Extension for PyTorch*, which brings the latest performance optimizations to Intel hardware. Intel® Extension for PyTorch* provides easy GPU acceleration for Intel discrete GPUs through the PyTorch* xpu device.

*Please refer to the [Installation Guide](https://intel.github.io/intel-extension-for-pytorch/xpu/2.1.10+xpu/tutorials/installation.html) for the system requirements and steps to install and use Intel® Extension for PyTorch*. For more detailed tutorials and documentation describing features, APIs and technical details, please refer to [Intel® Extension for PyTorch* Documentation](https://intel.github.io/intel-extension-for-pytorch/xpu/2.1.10+xpu/index.html).

In [None]:
# Check the Installation Guide and Documentation of IPEX for additional requirements
!pip install -r requirements.txt

In [None]:
from pathlib import Path
import os
import shutil
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
from torch import optim
import torch.nn.functional as F
import torch
import torchvision.transforms as T
import torchvision
import glob
import tqdm
from PIL import Image
import numpy as np

############# code changes ###############
# Device that the model and every batch are moved to in the cells below;
# 'xpu' is the PyTorch device name used for Intel GPUs.
device='xpu'

## Download data

In [None]:
import subprocess

# Everything the notebook downloads lives in ./data next to the notebook.
common_data_folder = Path.cwd() / 'data'
zip_file_path = common_data_folder / 'tiny-imagenet-200.zip'
common_data_folder.mkdir(parents=True, exist_ok=True)

# Download only when the archive is not there yet (the job `--no-clobber`
# used to do).  The file is fetched under a temporary name and renamed on
# success, so an interrupted download is never mistaken for a finished one.
# Passing the arguments as a list keeps paths with spaces intact, and
# check=True surfaces a failed download here instead of as a confusing
# unpack error below.
if not zip_file_path.exists():
    partial_path = zip_file_path.with_name(zip_file_path.name + '.part')
    subprocess.run(
        ['wget', 'http://cs231n.stanford.edu/tiny-imagenet-200.zip',
         '-O', str(partial_path)],
        check=True,
    )
    partial_path.replace(zip_file_path)

# Unpacks into <data>/tiny-imagenet-200/.
shutil.unpack_archive(str(zip_file_path), str(common_data_folder))

In [None]:
class TinyImageNetDataset(Dataset):
    """TinyImageNet dataset read from an unpacked ``tiny-imagenet-200`` folder.

    The folder must contain ``wnids.txt`` (one class id per line) and a
    sub-folder named after ``data_type`` holding the ``*.JPEG`` images.
    Class ids are sorted and numbered from 0; ``labels`` maps an image file
    name to that number.
    """

    # Number of training images assumed per class when building train labels.
    NUM_IMAGES_PER_CLASS = 500

    def __init__(self, data_folder: Path, data_type='train', transform=None):
        """Index the images of one split and build the label mapping.

        Args:
            data_folder: Root of the unpacked dataset (contains ``wnids.txt``).
            data_type: Name of the split sub-folder; labels are only filled
                for ``'train'`` and ``'val'``.
            transform: Optional callable applied to every loaded image.
        """
        self.data_type = data_type
        self._common_data_folder = data_folder
        self._data_folder = os.path.join(data_folder, data_type)
        self.labels = {}  # fname - label number mapping
        # Sorted so that the sample order is the same on every machine.
        self.image_paths = sorted(
            glob.iglob(
                os.path.join(self._data_folder, '**', '*.JPEG'),
                recursive=True
            )
        )
        with open(os.path.join(self._common_data_folder, 'wnids.txt'), 'r') as fp:
            # Blank lines are skipped: an empty class id would sort first and
            # silently shift every class number by one.
            self.label_texts = sorted(
                text for text in (line.strip() for line in fp) if text
            )
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}
        self.fill_labels()
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of images found for this split."""
        return len(self.image_paths)

    def __getitem__(self, index: int):
        """Return ``(sample, label)`` for the image at ``index``.

        ``sample`` is the opened image, passed through ``transform`` when one
        was given; ``label`` is looked up by the image's file name.
        """
        file_path = self.image_paths[index]
        sample = self.read_image(file_path)
        if self.transform:
            sample = self.transform(sample)
        label = self.labels[os.path.basename(file_path)]
        return sample, label

    def read_image(self, path: Path):
        """Open the image at ``path`` with PIL and return it."""
        img = Image.open(path)
        return img

    def fill_labels(self) -> None:
        """Populate ``self.labels`` for the ``'train'`` or ``'val'`` split.

        Train file names are generated as ``<class id>_<n>.JPEG`` for
        ``n < NUM_IMAGES_PER_CLASS``.  Validation labels come from the first
        two tab-separated columns of ``val_annotations.txt``.  Any other
        ``data_type`` leaves ``labels`` empty.
        """
        if self.data_type == 'train':
            for label_text, i in self.label_text_to_number.items():
                for cnt in range(self.NUM_IMAGES_PER_CLASS):
                    self.labels[f'{label_text}_{cnt}.JPEG'] = i
        elif self.data_type == 'val':
            with open(os.path.join(self._data_folder, 'val_annotations.txt'), 'r') as fp:
                for line in fp:
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    terms = line.split('\t')
                    file_name, label_text = terms[0], terms[1]
                    self.labels[file_name] = self.label_text_to_number[label_text]

In [None]:
def _to_rgb(image):
    """Return ``image`` converted to 3-channel RGB."""
    return image.convert("RGB")


# Per-channel normalisation applied as the last step of both pipelines.
# NOTE(review): these match the usual ImageNet channel statistics - confirm.
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Flip + rotate + resized crop, applied together with probability 0.8.
augmentation = T.RandomApply(
    [
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.RandomResizedCrop(64),
    ],
    p=.8,
)

# Training: RGB -> tensor -> random augmentation -> normalise.
training_transform = T.Compose(
    [T.Lambda(_to_rgb), T.ToTensor(), augmentation, normalize]
)

# Validation: same pipeline without the random augmentation.
valid_transform = T.Compose(
    [T.Lambda(_to_rgb), T.ToTensor(), normalize]
)

In [None]:
def get_train_loader():
    """Build a shuffling DataLoader (batch size 64) over the training split.

    Every call creates a new generator seeded with 0 for the shuffle, so the
    loader returned by each call is driven by the same seed.
    """
    dataset_root = common_data_folder / 'tiny-imagenet-200'
    train_set = TinyImageNetDataset(dataset_root, transform=training_transform)

    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(0)

    return DataLoader(
        train_set,
        batch_size=64,
        shuffle=True,
        generator=shuffle_rng,
    )

def get_valid_loader():
    """Build an unshuffled DataLoader (batch size 64) over the 'val' split."""
    dataset_root = common_data_folder / 'tiny-imagenet-200'
    valid_set = TinyImageNetDataset(
        dataset_root,
        data_type='val',
        transform=valid_transform,
    )
    return DataLoader(valid_set, batch_size=64)

In [None]:
class Net(nn.Module):
    """Pretrained MobileNetV2 whose final classifier layer outputs 200 classes."""

    def __init__(self):
        """Load the pretrained backbone and attach a fresh 200-way head."""
        # Seed before any layer is created so the randomly initialised head
        # is built from the same RNG state on every construction.
        torch.manual_seed(0)
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=` in
        # favour of `weights=` - confirm the installed version before switching.
        self.model = torchvision.models.mobilenet_v2(pretrained=True)
        # To train only the new head, freeze the backbone here with
        # self.model.requires_grad_(False) before the head is replaced.
        # Read the input width from the layer being replaced instead of
        # hard-coding it.
        head_in_features = self.model.classifier[1].in_features
        self.model.classifier[1] = torch.nn.Linear(
            in_features=head_in_features, out_features=200, bias=True)

    def forward(self, x):
        """Return the wrapped model's output for the batch ``x``."""
        # Call the module itself (not .forward) so that hooks registered on
        # the wrapped model are executed.
        return self.model(x)

# Single network instance that the cells below optimise, train and validate.
model = Net()

In [None]:
# Adam (lr 1e-4) over only those parameters that still require gradients.
trainable_parameters = [
    parameter for parameter in model.parameters() if parameter.requires_grad
]
optimizer = optim.Adam(trainable_parameters, lr=1e-4)

In [None]:
# Loss used by train(): functional cross-entropy, called as
# loss_fn(model_output, int64_targets).
loss_fn = F.cross_entropy

In [None]:
############# code changes ###############
import intel_extension_for_pytorch as ipex
# Move the model to the target device before handing it to IPEX.
model.to(device)
# ipex.optimize returns the (model, optimizer) pair to use from here on;
# bfloat16 matches the autocast dtype used in train() and validate().
model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)

def train(model, optimizer):
    """Train ``model`` for one pass over the training loader.

    Args:
        model: Network to train; switched to training mode here.
        optimizer: Optimizer stepped once per batch.

    Returns:
        ``{'train_loss': <mean of the per-batch losses>}``.
    """
    # Reseed so every epoch starts from the same global RNG state.
    torch.manual_seed(0)

    batches = tqdm.tqdm(get_train_loader(), desc="train")
    model.train()
    batch_losses = []

    for inputs, labels in batches:
        ############# code changes ###############
        inputs = inputs.to(device)
        labels = labels.to(device, dtype=torch.int64)

        optimizer.zero_grad()
        ############# code changes ###############
        # Forward pass and loss run under bfloat16 autocast on the XPU.
        with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
            logits = model(inputs)
            batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Keep a detached host-side copy of the loss for the epoch average.
        batch_losses.append(batch_loss.detach().cpu().numpy())

    return {'train_loss': np.mean(batch_losses),}

def validate(model):
    """Compute top-1 accuracy of ``model`` over the validation loader.

    Args:
        model: Network to evaluate; switched to eval mode here.

    Returns:
        ``{'acc': <correct predictions / number of samples>}``.
    """
    torch.manual_seed(0)
    model.eval()

    data_loader = tqdm.tqdm(get_valid_loader(), desc="validate")
    val_score = 0
    total_samples = 0

    with torch.no_grad():
        for data, target in data_loader:
            total_samples += target.shape[0]
            ############# code changes ###############
            data, target = data.to(device), target.to(device, dtype=torch.int64)
            ############# code changes ###############
            with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
                output = model(data)
                # No keepdim: pred must have the same (batch,) shape as
                # target.  A (batch, 1) pred would broadcast against the
                # (batch,) target into a (batch, batch) comparison and count
                # every prediction/target pair instead of matching samples.
                pred = output.argmax(dim=1)
            val_score += pred.eq(target).sum().cpu().numpy()

    return {'acc': val_score / total_samples,}

In [None]:
def _print_first_metric(metrics):
    """Print the first ``name: value`` entry of a metrics dict."""
    name, value = next(iter(metrics.items()))
    print(f'{name}: {value:f}')


# Accuracy before any training step has been taken.
_print_first_metric(validate(model))

# Five epochs, each followed by a validation pass.
for _ in range(5):
    _print_first_metric(train(model, optimizer))
    _print_first_metric(validate(model))