Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions e2e/integration_tests/backends_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,80 @@ describe(`${SMOKE} backends`, () => {
expect(after).toBe(webglBefore);
});
});

it('can execute op with data from mixed backends', async () => {
const numTensors = tfc.memory().numTensors;
const webglNumDataIds = tfc.findBackend('webgl').numDataIds();
const cpuNumDataIds = tfc.findBackend('cpu').numDataIds();

await tfc.setBackend('cpu');
// This scalar lives in cpu.
const a = tfc.scalar(5);

await tfc.setBackend('webgl');
// This scalar lives in webgl.
const b = tfc.scalar(3);

// Verify that ops can execute with mixed backend data.
tfc.engine().startScope();

await tfc.setBackend('cpu');
const result = tfc.add(a, b);
tfc.test_util.expectArraysClose(await result.data(), [8]);
expect(tfc.findBackend('cpu').numDataIds()).toBe(cpuNumDataIds + 3);

await tfc.setBackend('webgl');
tfc.test_util.expectArraysClose(await tfc.add(a, b).data(), [8]);
expect(tfc.findBackend('webgl').numDataIds()).toBe(webglNumDataIds + 3);

tfc.engine().endScope();

expect(tfc.memory().numTensors).toBe(numTensors + 2);
expect(tfc.findBackend('webgl').numDataIds()).toBe(webglNumDataIds + 2);
expect(tfc.findBackend('cpu').numDataIds()).toBe(cpuNumDataIds);

tfc.dispose([a, b]);

expect(tfc.memory().numTensors).toBe(numTensors);
expect(tfc.findBackend('webgl').numDataIds()).toBe(webglNumDataIds);
expect(tfc.findBackend('cpu').numDataIds()).toBe(cpuNumDataIds);
});

// tslint:disable-next-line: ban
xit('can move complex tensor from cpu to webgl.', async () => {
await tfc.setBackend('cpu');

const real1 = tfc.tensor1d([1]);
const imag1 = tfc.tensor1d([2]);
const complex1 = tfc.complex(real1, imag1);

await tfc.setBackend('webgl');

const real2 = tfc.tensor1d([3]);
const imag2 = tfc.tensor1d([4]);
const complex2 = tfc.complex(real2, imag2);

const result = complex1.add(complex2);

tfc.test_util.expectArraysClose(await result.data(), [4, 6]);
});

// tslint:disable-next-line: ban
xit('can move complex tensor from webgl to cpu.', async () => {
await tfc.setBackend('webgl');

const real1 = tfc.tensor1d([1]);
const imag1 = tfc.tensor1d([2]);
const complex1 = tfc.complex(real1, imag1);

await tfc.setBackend('cpu');

const real2 = tfc.tensor1d([3]);
const imag2 = tfc.tensor1d([4]);
const complex2 = tfc.complex(real2, imag2);

const result = complex1.add(complex2);

tfc.test_util.expectArraysClose(await result.data(), [4, 6]);
});
});
481 changes: 238 additions & 243 deletions e2e/yarn.lock

Large diffs are not rendered by default.

Loading