In [1]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage
import torch
import torch.nn as nn
from tqdm.notebook import trange

In [282]:
# Floating-point precision shared by the network weights and the sampled labels.
# NOTE: this is 32-bit (single) precision; use torch.float64 for 64-bit.
dtype = torch.float32

# Fully connected network mapping a 2-D point to a single logit: three hidden
# layers of width 256 with LeakyReLU activations, built on the CPU and then
# moved to the GPU.
mlp = nn.Sequential(
    nn.Linear(2, 256, dtype=dtype), nn.LeakyReLU(),
    nn.Linear(256, 256, dtype=dtype), nn.LeakyReLU(),
    nn.Linear(256, 256, dtype=dtype), nn.LeakyReLU(),
    nn.Linear(256, 1, dtype=dtype),
).to('cuda')

In [419]:
# def sample_func(x, y):
#     return (torch.sin(2 * torch.pi * x) - torch.cos(2 * torch.pi * y) > 0.5).float()
# def sample_func(x):
#     return (torch.sin(0.5 * torch.pi * x)>0.5).float()
def circle_sdf(center, radius, point):
    """Signed distance from one or more points to a circle.

    Args:
        center: Circle centre, array-like of shape (D,).
        radius: Circle radius.
        point: Query point(s), array-like of shape (D,) or (..., D).

    Returns:
        Euclidean distance to ``center`` minus ``radius``: negative inside the
        circle, zero on it, positive outside. A NumPy scalar for a single
        point, otherwise an array shaped like ``point`` without its last axis.
    """
    offsets = np.asarray(point) - np.asarray(center)
    # Reduce over the last (coordinate) axis instead of a hard-coded axis=1 so
    # that a single point and batched points work as well as an (N, D) array.
    return np.linalg.norm(offsets, axis=-1) - radius

@torch.no_grad()
def sample_func(x):
    """Label points by whether they lie inside the disc x^2 + y^2 < 0.3.

    Args:
        x: Tensor whose last axis holds at least the two coordinates (x, y).

    Returns:
        Tensor of the global ``dtype`` with the last axis removed: 1 where the
        point is strictly inside the disc, 0 elsewhere.
    """
    px = x[..., 0]
    py = x[..., 1]
    inside = px * px + py * py < 0.3
    return inside.to(dtype)

In [284]:
# Coarse 10x10 evaluation grid over [-1, 1]^2, flattened to (100, 2) points.
# The samples use the same dtype as the network so they can be fed to it directly.
x = torch.linspace(-1, 1, 10, dtype=dtype).to('cuda')
y = torch.linspace(-1, 1, 10, dtype=dtype).to('cuda')
# indexing='ij' is the layout the legacy default produced (xx varies along
# rows, yy along columns); passing it explicitly avoids the deprecation warning.
xx, yy = torch.meshgrid(x, y, indexing='ij')
# The meshgrid outputs already live on the GPU, so only the reshape is needed.
xx = xx.reshape(-1, 1)
yy = yy.reshape(-1, 1)
xy = torch.cat([xx, yy], dim=1)
# Ground-truth inside/outside labels for every grid point, shape (100,).
z = sample_func(xy)

In [None]:
# Filled contour plot of the ground-truth labels on the coarse grid, drawn on
# 3-D axes. Tensors are copied to host memory before plotting.
grid_shape = (10, 10)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.contourf(
    xx.view(grid_shape).cpu().numpy(),
    yy.view(grid_shape).cpu().numpy(),
    z.view(grid_shape).cpu().numpy(),
)
plt.show()

In [None]:
# Inspect the shape of the most recent network output.
# NOTE(review): in the visible notebook `y_pred` is first assigned by the
# training loop in a later cell, so on a fresh kernel this raises NameError
# unless that cell has already been run.
y_pred.shape

In [None]:
# Fit the MLP as a binary classifier: its raw output is treated as a logit and
# compared against the 0/1 labels with a sigmoid + binary cross-entropy loss.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)

# 20 training points drawn uniformly from [-1, 1)^2, and their labels.
gt_x = torch.rand(20, 2, device='cuda', dtype=dtype) * 2 - 1
gt_y = sample_func(gt_x)
# Labels as a (20, 1) column so they line up with the network output.
gt_target = gt_y.unsqueeze(-1)

tbar = trange(10000)
for i in tbar:
    optimizer.zero_grad()
    y_pred = mlp(gt_x)
    loss = loss_fn(y_pred, gt_target)
    loss.backward()
    optimizer.step()
    # Show the current loss next to the progress bar.
    tbar.set_description(f'Loss: {loss.item():.4f}')


In [290]:
# Raw logits of the network on the current `xy` grid, arranged as a 10x10 map.
y_raw = mlp(xy).view(10, 10)
# Logits squashed to probabilities and copied to a host NumPy array for plotting.
y_sgm = y_raw.sigmoid().detach().cpu().numpy()

In [None]:
# Filled contour plot of the predicted probabilities on the coarse grid, drawn
# on 3-D axes with the same layout as the ground-truth plot.
grid_shape = (10, 10)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.contourf(
    xx.view(grid_shape).cpu().numpy(),
    yy.view(grid_shape).cpu().numpy(),
    y_sgm,
)
plt.show()

In [None]:
# Probe the network on a 100x100 grid inside a tiny (1e-7 wide) window whose
# upper corner is (0, 0.5).
# NOTE(review): in float32 the y samples (nominal spacing ~1e-9 near 0.5)
# collapse to only a few distinct values; float64 may be what is intended.
x = torch.linspace(-0.0000001, 0.0, 100, dtype=dtype).to('cuda')
y = torch.linspace(0.5-0.0000001, 0.5, 100, dtype=dtype).to('cuda')
# torch.meshgrid only accepts scalar or 1-D tensors, so the axes are kept 1-D
# here (reshaping them to (100, 1) first makes the call raise) and the grids
# are flattened afterwards.
x, y = torch.meshgrid(x, y, indexing='ij')
x = x.reshape(-1, 1)
y = y.reshape(-1, 1)
detail_xy = torch.cat([x, y], dim=1)

# Raw network output for every probe point, as a host NumPy array.
y_pred = mlp(detail_xy).detach().cpu().numpy()
# plt.plot(x.detach().cpu().numpy(), y_pred)

In [None]:
# Display the tensor currently bound to `x`.
# NOTE(review): several cells in this notebook rebind `x`, so what is shown
# depends on which of them ran last.
x

In [None]:
# Samples per axis and half-width of the square plotting window around 0.
resolution = 100
pad = 1
x = torch.linspace(0.0-pad, 0.0+pad, resolution, dtype=dtype).to('cuda')
y = torch.linspace(0.0-pad, 0.0+pad, resolution, dtype=dtype).to('cuda')
# indexing='ij' is the layout the legacy default produced; passing it explicitly
# avoids the deprecation warning.
xx, yy = torch.meshgrid(x, y, indexing='ij')
# The grids already live on the GPU; flatten them to 1-D coordinate lists.
xx = xx.reshape(-1)
yy = yy.reshape(-1)
xy = torch.stack([xx, yy], dim=1)

# Analytic surface: the disc indicator shifted to +/-0.5 plus a small cosine
# ripple in each coordinate.
y_raw = torch.cos(4*torch.pi*xx) * 0.05 + torch.cos(4*torch.pi*yy) * 0.05 + sample_func(xy) - 0.5
# Despite the name, no sigmoid is applied here; this is y_raw as a 2-D host array.
y_sgm = y_raw.view(resolution, resolution).detach().cpu().numpy()

# Convert the coordinate grids to host arrays once and reuse them for both plots.
xx_np = xx.view(resolution, resolution).cpu().numpy()
yy_np = yy.view(resolution, resolution).cpu().numpy()

ax = plt.figure().add_subplot(projection='3d')
ax.plot_surface(xx_np, yy_np, y_sgm, rstride=8, cstride=8, lw=0.5,
                edgecolor='royalblue', alpha=0.3)
# Project filled contours of the same surface onto the plane z = -2.
ax.contourf(xx_np, yy_np, y_sgm, [-1.0, -0.2, 0.0, 0.2, 1.0], zdir='z', offset=-2, cmap='viridis')
# Hide the tick labels on all three axes while keeping the ticks themselves.
ax.xaxis.set_ticklabels([])
ax.yaxis.set_ticklabels([])
ax.zaxis.set_ticklabels([])
ax.set_zlim(-2, 1)

plt.show()

In [None]:
def circle_sdf(center, radius, point):
    """Signed distance from one or more points to a circle.

    Args:
        center: Circle centre, array-like of shape (D,).
        radius: Circle radius.
        point: Query point(s), array-like of shape (D,) or (..., D).

    Returns:
        Euclidean distance to ``center`` minus ``radius``: negative inside the
        circle, zero on it, positive outside. A NumPy scalar for a single
        point, otherwise an array shaped like ``point`` without its last axis.
    """
    offsets = np.asarray(point) - np.asarray(center)
    # Reduce over the last (coordinate) axis instead of a hard-coded axis=1 so
    # that a single point and batched points work as well as an (N, D) array.
    return np.linalg.norm(offsets, axis=-1) - radius

In [None]:
# 100x100 lattice spanning [-5, 5] on both axes (complex step => endpoints included).
grid_x, grid_y = np.mgrid[-5:5:100j, -5:5:100j]
# One (x, y) row per lattice node in row-major order, shape (10000, 2).
xy = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)

In [None]:
# Signed distance fields of five radius-4 circles, evaluated on the lattice
# coordinates scaled by 2 and reshaped back to the 100x100 grid.
scaled_xy = xy * 2
sdf0, sdf1, sdf2, sdf3, sdf4 = (
    circle_sdf(center, 4, scaled_xy).reshape(100, 100)
    for center in [(0.8, 0), (0, 2), (0, -2), (0.6, -1), (0.6, 1)]
)

# Pointwise minimum over the five fields (the usual union of distance fields).
sdf = np.minimum.reduce([sdf0, sdf1, sdf2, sdf3, sdf4])
# sdf = sdf

In [None]:
# 3x3 box (mean) filter applied 30 times in a row to blur the field.
# NOTE: `sdf` is overwritten in place, so re-running this cell blurs it further.
kernal = np.full((3, 3), 1 / 9)
for i in range(30):
    sdf = scipy.ndimage.convolve(sdf, weights=kernal)

In [None]:
import io

# Render the smoothed field as a borderless 6x6 inch image: filled contours with
# the -2.38 iso-line highlighted in thick red.
plt.figure(figsize=(6, 6))
plt.contourf(sdf, origin='lower', cmap='viridis')
plt.contour(sdf, [-2.38], colors=['red'], linestyles='solid', linewidths=5)
# Drop the axes and let the plot fill the whole canvas.
plt.axis('off')
plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

# Save into an in-memory PNG buffer instead of a file on disk.
buf = io.BytesIO()
plt.savefig(buf, format='png')
# plt.adj
# plt.show()

In [None]:
# NOTE: this expression is not the last statement of the cell, so by default
# Jupyter does not display its value.
sdf.shape
# Rebind x and y to `resolution` evenly spaced samples on [-pad, pad], on the GPU.
# `pad`, `resolution` and `dtype` come from earlier cells.
x = torch.linspace(0.0-pad, 0.0+pad, resolution, dtype=dtype).to('cuda')
y = torch.linspace(0.0-pad, 0.0+pad, resolution, dtype=dtype).to('cuda')

In [None]:
# Height field: sigmoid of the shifted and scaled SDF, plus a small cosine ripple
# along each axis (added left to right, x ripple first).
squashed = torch.sigmoid(torch.tensor((sdf+1)*5)).numpy()
ripple_x = 0.03*np.cos(grid_x*2)
ripple_y = 0.03*np.cos(grid_y*2)
sigmoid_sdf = squashed + ripple_x + ripple_y

ax = plt.figure().add_subplot(projection='3d')
# Semi-transparent wireframe-like surface of the height field.
surface_style = dict(rstride=10, cstride=10, lw=0.5,
                     edgecolor='royalblue', alpha=0.3, zorder=1)
ax.plot_surface(grid_x, grid_y, sigmoid_sdf, **surface_style)

# Filled contours of the raw SDF projected onto the plane z = -0.2.
ax.contourf(grid_x, grid_y, sdf, zdir='z', offset=-0.2, cmap='viridis', zorder=1)

# Hide the tick labels on all three axes while keeping the ticks themselves.
for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    axis.set_ticklabels([])
ax.set_zlim(-0.2, 1)

plt.show()

In [None]:
# Reopen the PNG that was rendered into the in-memory buffer as a PIL image and
# create a drawing context on it.
from PIL import Image, ImageDraw
buf.seek(0)  # rewind to the start of the byte stream before reading it back
im = Image.open(buf)
draw = ImageDraw.Draw(im)

In [None]:
# Black bars along the top and left edges of the image. Each line is centred at
# half its width, so the bars start at the image border.
# NOTE: this draws onto `im` in place.
top_edge = [(0, 2.5), (im.width, 2.5)]
left_edge = [(5, 0), (5, im.height)]
draw.line(top_edge, fill='black', width=5)
draw.line(left_edge, fill='black', width=10)
im

In [None]:
# 100 evenly spaced inputs on [-8, 8] and the logistic sigmoid evaluated at each.
x = torch.linspace(-8, 8, steps=100)
y = x.sigmoid()

In [None]:
# Thick dark-purple curve of y against x on a light-blue square canvas, axes hidden.
background_rgb = [154/255, 201/255, 219/255]
curve_rgb = [72/255, 27/255, 109/255]
plt.figure(figsize=(6, 6), facecolor=background_rgb)
plt.plot(x, y, c=curve_rgb, linewidth=10)
plt.axis('off')

In [None]:
# Plot the sigmoid of the network output against x.
# NOTE(review): x[..., None] appends a size-1 feature axis, whereas the mlp built
# at the top of this notebook takes 2 input features and lives on the GPU. This
# looks like a leftover from a 1-D experiment; confirm the intended `x` and model
# before running.
y_pred = torch.sigmoid(mlp(x[..., None])).detach().cpu().numpy()
plt.plot(x.detach().cpu().numpy(), y_pred)

In [None]:
# Plot the sigmoid of the network output against x (same code as the cell above).
# NOTE(review): x[..., None] appends a size-1 feature axis, whereas the mlp built
# at the top of this notebook takes 2 input features and lives on the GPU. This
# looks like a leftover from a 1-D experiment; confirm the intended `x` and model
# before running.
y_pred = torch.sigmoid(mlp(x[..., None])).detach().cpu().numpy()
plt.plot(x.detach().cpu().numpy(), y_pred)

In [None]:
# sin(2*pi*x) sampled at 1000 points on [-1, 1], kept as (1000, 1) GPU tensors.
x = torch.linspace(-1, 1, 1000).reshape(-1, 1).to('cuda')
# y inherits x's device, so no extra transfer is needed.
y = torch.sin(2 * torch.pi * x)

# Matplotlib converts its inputs to NumPy arrays, which fails for CUDA tensors,
# so both series are copied to host memory first.
plt.plot(x.detach().cpu().numpy(), y.detach().cpu().numpy())
plt.show()