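## PyTorch dataset

A map-style `torch.utils.data.Dataset` subclass implements `__getitem__` (return one sample by index) and `__len__` (return the number of samples).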

In [2]:
from torch.utils.data import Dataset
import torch

In [3]:
class toy_set(Dataset):
    """Toy map-style dataset of ``length`` identical ``(x, y)`` samples.

    Every row of ``x`` is ``[2., 2.]`` and every row of ``y`` is ``[1.]``.

    Args:
        length: number of samples (default 100).
        transform: optional callable applied to each ``(x, y)`` tuple returned
            by ``__getitem__``; ``None`` (the default) means no transform.
    """

    def __init__(self, length=100, transform=None):
        self.x = 2 * torch.ones(length, 2)  # features, shape (length, 2)
        self.y = torch.ones(length, 1)      # targets, shape (length, 1)
        self.len = length
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(x, y)`` sample at ``index``, passed through ``transform`` if one is set.

        For an integer index the returned tensors are views into ``self.x`` /
        ``self.y``, so a transform that modifies them in place would alter the
        stored data.
        """
        sample = self.x[index], self.y[index]
        # Compare against None instead of relying on truthiness: a callable
        # object that happens to be falsy (e.g. defines __len__ returning 0)
        # must still be applied.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Return the number of samples (the ``length`` given at construction)."""
        return self.len


dataset = toy_set()

In [5]:
n_samples = len(dataset)  # delegates to toy_set.__len__
n_samples

100

In [6]:
sample_0 = dataset[0]  # delegates to toy_set.__getitem__
sample_0

(tensor([2., 2.]), tensor([1.]))

In [7]:
# Peek at the first three (x, y) pairs.
for i in range(3):
    sample = dataset[i]
    x, y = sample
    print(i, 'x:', x, 'y:', y)

0 x: tensor([2., 2.]) y: tensor([1.])
1 x: tensor([2., 2.]) y: tensor([1.])
2 x: tensor([2., 2.]) y: tensor([1.])


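## Transforms

A transform is a callable that `toy_set.__getitem__` applies to each `(x, y)` sample before returning it.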

In [16]:
class add_mult:
    """Transform that shifts the first element of a sample and scales the second.

    Args:
        addx: value added to ``sample[0]`` (default 1).
        muly: factor ``sample[1]`` is multiplied by (default 1).
    """

    def __init__(self, addx=1, muly=1):
        self.addx = addx
        self.muly = muly

    def __call__(self, sample):
        """Return the tuple ``(sample[0] + addx, sample[1] * muly)``.

        Uses out-of-place ``+`` / ``*``, so tensor inputs are not modified.
        """
        features, target = sample[0], sample[1]
        return (features + self.addx, target * self.muly)


a_m = add_mult()

raw_sample = dataset[0]
x_, y_ = a_m(raw_sample)  # apply the transform by hand to one sample

In [19]:
dataset = toy_set(
    transform=a_m,  # a_m is applied to every sample on indexing
)

In [20]:
transformed_sample = dataset[0]
transformed_sample

(tensor([3., 3.]), tensor([1.]))

## Composing transforms

`transforms.Compose` chains several transforms, applying them in list order.

In [22]:
class mult:
    """Transform that scales both elements of a sample by the same factor.

    Args:
        mul: factor applied to ``sample[0]`` and ``sample[1]`` (default 100).
    """

    def __init__(self, mul=100):
        self.mul = mul

    def __call__(self, sample):
        """Return the tuple ``(sample[0] * mul, sample[1] * mul)``."""
        factor = self.mul
        return tuple(part * factor for part in (sample[0], sample[1]))

In [23]:
from torchvision import transforms

# Order matters: add_mult() (x + 1, y * 1) runs first, then mult() (both * 100).
transform_steps = [add_mult(), mult()]
data_transform = transforms.Compose(transform_steps)

In [24]:
data_set_tr = toy_set(
    transform=data_transform,  # composed pipeline applied on every access
)

In [25]:
composed_sample = data_set_tr[0]
composed_sample

(tensor([300., 300.]), tensor([100.]))