
Commit

torch_tensor ad within taichi
yuanming-hu committed Oct 9, 2019
1 parent 100a732 commit bc9ea32
Showing 3 changed files with 87 additions and 13 deletions.
File renamed without changes.
77 changes: 77 additions & 0 deletions examples/torch_tensor_ad.py
@@ -0,0 +1,77 @@
import taichi as ti
import numpy as np
import torch

# ti.set_gdb_trigger(True)
ti.cfg.arch = ti.cuda

# n = 1024 * 1024
n = 32

x = ti.var(ti.f32)
y = ti.var(ti.f32)

# https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

@ti.layout
def values():
    # actually useless in this example
    ti.root.dense(ti.i, n).place(x)
    ti.root.dense(ti.i, n).place(y)
    ti.root.lazy_grad()

@ti.kernel
def torch_kernel():
    for i in range(n):
        y[i] = x[i] * x[i]


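# copy_from/copy_to build tiny Taichi kernels that shuttle data between
# a torch tensor and a Taichi tensor. ker.materialize() compiles the
# kernel eagerly, and the returned lambda makes the torch tensor
# contiguous before every call.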
def copy_from(taichi_tensor):
    @ti.kernel
    def ker(torch_tensor: np.ndarray):
        for i in taichi_tensor:
            taichi_tensor[i] = torch_tensor[i]

    ker.materialize()
    return lambda x: ker(x.contiguous())

def copy_to(taichi_tensor):
    @ti.kernel
    def ker(torch_tensor: np.ndarray):
        for i in taichi_tensor:
            torch_tensor[i] = taichi_tensor[i]

    ker.materialize()
    return lambda x: ker(x.contiguous())

x_copy_from = copy_from(x)
y_copy_to = copy_to(y)

y_grad_copy_from = copy_from(y.grad)
x_grad_copy_to = copy_to(x.grad)

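# Bridge into torch.autograd: forward copies inp into the Taichi tensor
# x, runs torch_kernel, and copies y back out; backward clears stale
# gradients, seeds y.grad from outp_grad, runs the kernel's
# automatically generated adjoint via torch_kernel.grad(), and reads
# the result back from x.grad.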
class Sqr(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        outp = torch.zeros_like(inp)
        x_copy_from(inp)
        torch_kernel()
        y_copy_to(outp)
        return outp

    @staticmethod
    def backward(ctx, outp_grad):
        ti.clear_all_gradients()
        inp_grad = torch.zeros_like(outp_grad)

        y_grad_copy_from(outp_grad)
        torch_kernel.grad()
        x_grad_copy_to(inp_grad)

        return inp_grad

sqr = Sqr.apply
X = torch.tensor(2 * np.ones((n, ), dtype=np.float32), device=torch.device('cuda:0'), requires_grad=True)
sqr(X).sum().backward()
print(X.grad.cpu())

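Since every entry of X is 2 and y = x * x elementwise, each entry of X.grad after the backward pass should equal 2 * 2 = 4. A quick sanity check one could append to the script (illustrative, not part of the commit):

    # d(x*x)/dx at x == 2 is 4 for every element
    expected = torch.full((n,), 4.0)
    assert torch.allclose(X.grad.cpu(), expected)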
23 changes: 10 additions & 13 deletions examples/torch_tensor_io.py
@@ -8,29 +8,28 @@
# n = 1024 * 1024
n = 32

-y = ti.var(ti.i32)
+y = ti.var(ti.f32)

# https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html

-z = np.array((n,), dtype=np.float32)

@ti.layout
def values():
-    ti.root.place(y)

+    # actually useless in this example
+    ti.root.dense(ti.i, n).place(y)
+    ti.root.lazy_grad()

@ti.kernel
def torch_kernel(t: np.ndarray, o: np.ndarray):
    for i in range(n):
        o[i] = t[i] * t[i]

@ti.kernel
def torch_kernel_2(t_grad: np.ndarray, t: np.ndarray, o_grad: np.ndarray):
    for i in range(n):
-        ti.print(o_grad[i])
        t_grad[i] = 2 * t[i] * o_grad[i]


class Sqr(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
@@ -41,16 +40,14 @@ def forward(ctx, inp):

    @staticmethod
    def backward(ctx, outp_grad):
-        print(outp_grad.cpu())
        outp_grad = outp_grad.contiguous()
        inp_grad = torch.zeros_like(outp_grad)
        inp, = ctx.saved_tensors
        torch_kernel_2(inp_grad, inp, outp_grad)
        return inp_grad

sqr = Sqr.apply
-x = torch.tensor(2 * np.ones((n, ), dtype=np.float32), device=torch.device('cuda:0'), requires_grad=True)
-sqr(x).sum().backward()
-# print(sqr(x).sum())#.backward()
-print(x.grad.cpu())
+X = torch.tensor(2 * np.ones((n, ), dtype=np.float32), device=torch.device('cuda:0'), requires_grad=True)
+sqr(X).sum().backward()
+print(X.grad.cpu())

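For comparison, the same custom Function in pure PyTorch, with no Taichi kernels involved (a minimal sketch, assuming only what the examples above already import):

    class SqrTorch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            ctx.save_for_backward(inp)
            return inp * inp

        @staticmethod
        def backward(ctx, outp_grad):
            # chain rule: d(x*x)/dx = 2x, scaled by the incoming gradient
            inp, = ctx.saved_tensors
            return 2 * inp * outp_grad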