In [1]:
# A notebook to test using streams in pytorch

In [2]:
from time import time

import torch


In [3]:
class NN(torch.nn.Module):
    
    def __init__(self, n):
        
        super().__init__()
        self.ll = torch.nn.Linear(n,n)
        
    def forward(self, x): 
        return torch.sum(torch.exp(self.ll(x)))

In [5]:
n_mats = 100000
d = 10
use_streams = False

t0 = time()
# Initialise cuda tensors here. E.g.:
mats = [torch.rand(d, d, device = 'cuda') for i in range(n_mats)]
nets = [NN(d)]
for nn in nets:
    nn.to('cuda')
t1 = time()
print('Created data in ' + str(t1-t0) + ' seconds.')

t2 = time()
streams = [torch.cuda.Stream() for i in range(n_mats)]
t3 = time()
print('Created streams in ' + str(t1-t0) + ' seconds.')



# Wait for the above tensors to initialise.
torch.cuda.synchronize()
if use_streams:
    t4 = time()
    # Do some basic computations with streams 
    r = [None]*n_mats
    for m_i, m in enumerate(mats):
        with torch.cuda.stream(streams[m_i]):
            r[m_i] = torch.mm(m,m)

    torch.cuda.synchronize()
    t5 = time()
    print('Computed with streams in ' + str(t5-t4) + ' seconds.')
else:
    t6 = time()
    r = [None]*len(mats)
    for m_i, m in enumerate(mats):
        r[m_i] = nn(m)
    torch.cuda.synchronize()
    print(r[0])
    t7 = time()
    print('Computed without streams in ' + str(t7-t6) + ' seconds.')



Created data in 1.2040178775787354 seconds.
Created streams in 1.2040178775787354 seconds.
tensor(104.0897, device='cuda:0', grad_fn=<SumBackward0>)
Computed without streams in 7.664313077926636 seconds.
