In [44]:
import numpy as np
import torch
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
from array2gif import write_gif


class MNIST_Video(Dataset):
    """Synthetic two-class "moving MNIST digit" video dataset.

    Each video places one MNIST digit on a CANVAS_SIZE x CANVAS_SIZE canvas for
    ``nframes`` frames. The digit's column offset grows by ``inc`` pixels per
    frame. Its row offset stays at ``startx`` until a per-video random choice
    point; after that it moves up (class 0) or down (class 1) by ``inc`` pixels
    per frame. The label is that vertical direction.

    An item is a sliding window of ``windowsize`` consecutive frames, so each
    video contributes ``nframes - windowsize + 1`` items.

    Args:
        root: directory holding the torchvision MNIST files. ``download=False``,
            so they must already be present.
        train: use the MNIST train split if True, else the test split.
        size: number of distinct digits (= videos), sampled without replacement.
        nframes: frames per video.
        windowsize: frames per item; must satisfy ``1 <= windowsize <= nframes``.
        seed: optional seed for sampling digits, choice points and labels.
            ``None`` (default) draws from NumPy's global RNG.

    Raises:
        ValueError: if ``windowsize`` is outside ``[1, nframes]``.
    """

    # Side length, in pixels, of the square canvas every frame is drawn on.
    CANVAS_SIZE = 128

    def __init__(self, root='/scratch1/datasets/mnist', train=True, size=500, nframes=48, windowsize=6,
                 seed=None):
        # Fail fast, before the (slow) dataset load.
        self._check_window(windowsize, nframes)
        # None keeps the original behaviour (global RNG); RandomState exposes the same `choice` API.
        rng = np.random if seed is None else np.random.RandomState(seed)

        mnist_dset = datasets.MNIST(root=root, train=train, download=False, transform=transforms.ToTensor())
        self.sample = rng.choice(len(mnist_dset), size, replace=False)
        # One image tensor per sampled digit, stacked along a new leading axis.
        self.dset = torch.stack([mnist_dset[i][0] for i in self.sample])

        # Frame index at which each video's digit starts moving vertically.
        self.choicePoints = rng.choice(nframes, size, replace=True)
        # Label per video: 0 -> moves up after the choice point, 1 -> moves down.
        self.classes = rng.choice(2, size, replace=True)
        self.nframes = nframes
        self.windowsize = windowsize
        self.netframes = nframes - windowsize + 1

        self.startx = 50  # initial row offset of the digit's top-left corner
        self.inc = 2      # pixels moved per frame, on each moving axis

    @staticmethod
    def _check_window(windowsize, nframes):
        """Raise ValueError unless 1 <= windowsize <= nframes (otherwise netframes <= 0)."""
        if not 1 <= windowsize <= nframes:
            raise ValueError(
                f'windowsize must be between 1 and nframes ({nframes}), got {windowsize}')

    def __len__(self):
        """Number of windows: one per (video, first frame) pair."""
        return len(self.dset) * self.netframes

    def __getitem__(self, index):
        """Return ``(frames, label)`` for one sliding window.

        ``index`` is decomposed as ``video * netframes + first_frame``.

        Returns:
            frames: float tensor of shape (windowsize, CANVAS_SIZE, CANVAS_SIZE),
                with additive Gaussian noise (std 0.1), then min-max rescaled
                to [0, 1] over the whole window.
            label: the video's class (0 or 1), as stored in ``self.classes``.
        """
        digitIndex, frameIndex = divmod(index, self.netframes)

        digit = self.dset[digitIndex]
        choicePoint = self.choicePoints[digitIndex]
        yval = self.classes[digitIndex]
        multiplier = -1 if yval == 0 else 1

        digit_h, digit_w = digit.shape[-2:]
        # Largest top-left offsets that keep the whole digit on the canvas.
        max_row = self.CANVAS_SIZE - digit_h
        max_col = self.CANVAS_SIZE - digit_w

        frames = np.arange(frameIndex, frameIndex + self.windowsize)
        # Row ("x") offset: constant until the choice point, then moves by `inc`
        # per frame (up for class 0, down for class 1), clamped to the canvas.
        outx = self.startx + self.inc * multiplier * np.maximum(frames - choicePoint, 0)
        outx = np.clip(outx, 0, max_row)
        # Column ("y") offset: grows by `inc` per frame. Clamped so long videos
        # (nframes > 51 at the defaults) no longer run off the canvas.
        outy = np.clip(self.inc * frames, 0, max_col)

        pts = torch.zeros(self.windowsize, self.CANVAS_SIZE, self.CANVAS_SIZE)
        for i, (row, col) in enumerate(zip(outx, outy)):
            pts[i, row:row + digit_h, col:col + digit_w] = digit[0]
        pts = pts + 0.1 * torch.randn(pts.size())
        # Global (whole-window) min-max rescale to [0, 1].
        pts = (pts - pts.min()) / (pts.max() - pts.min())

        return pts, yval

    def changeWindowSize(self, newSize):
        """Set the window length and recompute ``netframes``.

        Raises:
            ValueError: if ``newSize`` is outside ``[1, nframes]``; the dataset is left unchanged.
        """
        self._check_window(newSize, self.nframes)
        self.windowsize = newSize
        self.netframes = self.nframes - newSize + 1

    def write_gif(self, index, fname, fps=24):
        """Render the full noisy video number ``index`` to the GIF file ``fname``.

        The window is temporarily widened to all ``nframes`` frames, so ``index``
        selects the video directly. It is restored afterwards, even on error.
        """
        backupWindowSize = self.windowsize
        self.changeWindowSize(self.nframes)
        try:
            pts, _ = self.__getitem__(index)
        finally:
            self.changeWindowSize(backupWindowSize)
        # (nframes, H, W) -> (nframes, 3, H, W): grayscale replicated into 3 channels.
        rgb = pts.unsqueeze(1).repeat(1, 3, 1, 1).numpy()
        # `write_gif` here is the module-level array2gif function, not this method.
        write_gif(rgb * 255, fname, fps=fps)
        
        
        
        



In [43]:
data = MNIST_Video()
x, y = data[0]
print(x.size(), y)
# Render the first few full videos to disk for visual inspection.
for video_index in range(4):
    data.write_gif(video_index, f'test{video_index}.gif')

torch.Size([6, 128, 128]) 0


In [35]:
# Tensor.max() with no `dim` reduces over every element and returns a 0-d tensor.
torch.tensor(np.arange(25).reshape(5, 5), dtype=torch.float32).max()

tensor(24.)

In [36]:
# Unseeded draw from torch's global RNG: the values differ on every run.
torch.randn((3, 3))

tensor([[ 1.1057,  0.4758,  0.4337],
        [-0.3652,  0.5719,  0.5821],
        [-0.1327, -0.2686,  0.1621]])

In [None]:
# Local MNIST copy. download=False, so the files must already exist under this path.
MNIST_ROOT = '/scratch1/datasets/mnist'

# Train split first, then test split.
mnist_trainset, mnist_testset = [
    datasets.MNIST(root=MNIST_ROOT, train=is_train, download=False, transform=transforms.ToTensor())
    for is_train in (True, False)
]

print(len(mnist_trainset), len(mnist_testset), type(mnist_trainset))
x, y = mnist_trainset[5]
print(x.mean(), y)