# TorchUtils
> Some handy utilities for PyTorch

In [None]:
#| default_exp torchutils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from torch.utils.data import Dataset
import torch

In [None]:
#| export
def device_by_name(name: str) -> torch.device:
    ''' Return a reference to the first CUDA device whose name contains `name`

        Args:
            name: part of the CUDA device name (should be distinctive; the
                first device whose name contains it is returned)

        Return:
            Reference to the matching cuda device (``cuda:<index>``)

        Raises:
            AssertionError: if CUDA is not available or no device name matches

        Updated: Yuval 12/10/19
    '''
    # Raise AssertionError explicitly instead of using `assert`: the checks must
    # survive `python -O`, and callers catch AssertionError specifically.
    if not torch.cuda.is_available():
        raise AssertionError("No cuda device")
    for index in range(torch.cuda.device_count()):
        device = torch.device("cuda:{}".format(index))
        if name in torch.cuda.get_device_name(device):
            return device
    raise AssertionError("device {} not found".format(name))

In [None]:
# Render the API documentation for device_by_name
show_doc(device_by_name)

---

[source](https://github.com/yuval6957/reinautils/blob/main/reinautils/torchutils.py#L11){target="_blank" style="float:right; font-size:smaller"}

### device_by_name

>      device_by_name (name:str)

Return reference to cuda device by using part of its name

Args:
    name: part of the cuda device name (should be distinct)

Return:
    Reference to cuda device

Updated: Yuval 12/10/19

#### How to use

In [None]:
#| eval: false
# Look up the first CUDA device whose name contains "Tesla" (needs such a GPU)
device_by_name("Tesla")

device(type='cuda', index=0)

If no device name contains the given string, an `AssertionError` is raised:

In [None]:
#|eval: false
# An unknown device name must be rejected with AssertionError
try:
    device_by_name("fff")
except AssertionError:
    error = True
else:
    error = False
assert error

In [None]:
#| export
class DatasetCat(Dataset):
    '''
    Concatenate datasets column-wise for a PyTorch dataloader

    The normal PyTorch implementation (`ConcatDataset`) concatenates rows, i.e. it
    appends samples. This is a "column" implementation: item `idx` is the
    concatenation of item `idx` of every dataset, so all datasets must have the
    same length. Tuple items returned by a dataset are flattened one level.

    Args:
        datasets: the datasets to combine, all of the same length. A tuple of
            datasets may be passed as a single argument; its members are treated
            as if they had been passed individually.

    Updated: Yuval 12/10/2019
    '''

    def __init__(self, *datasets):
        '''
        Args: datasets - the datasets to combine (tuples of datasets are expanded)

        Raises:
            AssertionError: if no dataset is given or the dataset lengths differ
        '''
        super().__init__()
        self.datasets = datasets
        # __getitem__ expands tuples of datasets, so the length checks must look
        # at the datasets inside such tuples, not at the tuple itself.
        flat = self._flatten(datasets)
        # Explicit raise (not `assert`) so validation survives `python -O`.
        if not flat:
            raise AssertionError("At least one dataset is required")
        expected = len(flat[0])
        if any(len(dataset) != expected for dataset in flat):
            raise AssertionError("Datasets length should be equal")

    @staticmethod
    def _flatten(items):
        '''Expand tuple entries of `items` one level; non-tuple entries are kept as-is.'''
        return tuple(x for item in items for x in (item if isinstance(item, tuple) else (item,)))

    def __len__(self):
        return len(self._flatten(self.datasets)[0])

    def __getitem__(self, idx):
        # One entry per dataset; each dataset's tuple item is then flattened so
        # the result is a single flat tuple of fields.
        fields = tuple(dataset[idx] for dataset in self._flatten(self.datasets))
        return self._flatten(fields)
    

In [None]:
# Render the API documentation for DatasetCat
show_doc(DatasetCat)

---

[source](https://github.com/yuval6957/reinautils/blob/main/reinautils/torchutils.py#L33){target="_blank" style="float:right; font-size:smaller"}

### DatasetCat

>      DatasetCat (*datasets)

Concatenate datasets for Pytorch dataloader

The normal PyTorch implementation does it only for rows; this is a "column" implementation.

Args:
    datasets: list of datasets, of the same length

Updated: Yuval 12/10/2019

#### How to use

This is the first dataset:

In [None]:
# Five samples of (ones, random normal) pairs
dataset1 = torch.utils.data.TensorDataset(torch.ones(5, 1), torch.randn(5, 1))
print(len(dataset1))
print(dataset1[0])

5
(tensor([1.]), tensor([-1.2270]))


This is the second dataset:

In [None]:
# Five samples of (zeros, random normal) pairs
dataset2 = torch.utils.data.TensorDataset(torch.zeros(5, 1), torch.randn(5, 1))
print(len(dataset2))
print(dataset2[0])

5
(tensor([0.]), tensor([1.0632]))


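Now we concatenate them column-wise: each item of the combined dataset holds the fields of the corresponding items of both datasets.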

In [None]:
dataset3 = DatasetCat(dataset1, dataset2)
print(len(dataset3))
print(dataset3[0])
# Compare field by field with torch.equal: `==` between tuples of tensors calls
# bool() on an element-wise result, which fails for multi-element tensors.
expected = (*dataset1[3], *dataset2[3])
assert len(dataset3[3]) == len(expected)
assert all(torch.equal(got, want) for got, want in zip(dataset3[3], expected))
assert len(dataset3) == len(dataset1)

5
(tensor([1.]), tensor([-1.2270]), tensor([0.]), tensor([1.0632]))


In [None]:
#| hide
# Write the `#| export` cells of this notebook out to the library module
import nbdev
nbdev.nbdev_export()